#!/usr/bin/env python3

"""
Finite State Machine generator for the cozmo_fsm package.
Modeled after the Tekkotsu stateparser tool.

Usage:  genfsm [infile.fsm | -] [outfile.py | -]

Use '-' to indicate standard input or standard output.  If a
second argument is not supplied, writes to infile.py, or to
standard output if the input was '-'.

To enter state machine notation use a line that contains
just $setup ''', followed by the lines of the state machine,
and ending with a line containing just '''. This will result
in a definition of a setup() method for the state node class
you are defining. Example:

  class MyNode(StateNode):
      $setup '''
          Say("Hello") =C=> Forward(50)
      '''

See the cozmo_fsm/examples directory for examples of .fsm files
and the .py files they generate.

Note: install termcolor package to get messages displayed in color.

Author: David S. Touretzky, Carnegie Mellon University
"""

import sys, time, re

# colorama is optional; when it is installed, initialize it before any
# colored output is produced.
try:
    import colorama
    colorama.init()
except Exception:
    pass

# termcolor is optional too.  Without it, fall back to an uncolored cprint()
# that accepts the same keyword arguments this file passes.
try:
    from termcolor import cprint
except ImportError:
    def cprint(string, color=None, file=None):
        """Print `string` to `file` (standard output when None); `color` is ignored.

        The `file` argument must be honored: error messages are sent to
        sys.stderr so they do not end up in generated code that is being
        written to standard output.
        """
        print(string, file=file)


class Token:
    """Common base class for all lexer and parser tokens.

    Each is*() predicate reports whether this token is an instance of the
    correspondingly named token class.
    """

    def __repr__(self):
        return "<{}>".format(type(self).__name__)

    # Lexer tokens

    def isIdentifier(self):
        return isinstance(self, Identifier)

    def isColon(self):
        return isinstance(self, Colon)

    def isNewline(self):
        return isinstance(self, Newline)

    def isEqual(self):
        return isinstance(self, Equal)

    def isArrowHead(self):
        return isinstance(self, ArrowHead)

    def isComma(self):
        return isinstance(self, Comma)

    def isLBrace(self):
        return isinstance(self, LBrace)

    def isRBrace(self):
        return isinstance(self, RBrace)

    def isArglist(self):
        return isinstance(self, Arglist)

    # Parser stage 1 tokens

    def isLabelRef(self):
        return isinstance(self, LabelRef)

    def isLabelDef(self):
        return isinstance(self, LabelDef)

    def isConstructorCall(self):
        return isinstance(self, ConstructorCall)

    def isIdentifierList(self):
        return isinstance(self, IdentifierList)

    # Parser stage 2 token

    def isTransition(self):
        return isinstance(self, Transition)

    # Parser stage 3 tokens

    def isNodeDefinition(self):
        return isinstance(self, NodeDefinition)

# Lexer tokens:

class Identifier(Token):
    """Lexer token holding a name as it appeared in the input."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return "<Identifier {}>".format(self.name)

class Colon(Token):
    """Lexer token for ':'."""


class Equal(Token):
    """Lexer token for '='."""


class ArrowHead(Token):
    """Lexer token for '=>'."""


class Comma(Token):
    """Lexer token for ','."""


class LBrace(Token):
    """Lexer token for '{'."""


class RBrace(Token):
    """Lexer token for '}'."""


class Newline(Token):
    """Lexer token for a line break."""

class Arglist(Token):
    """Lexer token holding the text of a parenthesized argument list."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return "<Arglist {}>".format(self.value)

# Parser stage 1 and 2 tokens:

class IdentifierList(Token):
    """Parser token for a braced list of label names, e.g. {a,b,c}."""

    def __init__(self, label_refs):
        self.label_refs = label_refs

    def __repr__(self):
        joined = ','.join(self.label_refs)
        return "<IdentifierList {}>".format(joined)

class LabelDef(Token):
    """Parser token for a label definition (an identifier followed by ':')."""

    def __init__(self, label):
        self.label = label

    def __repr__(self):
        return "<LabelDef {}>".format(self.label)

class LabelRef(Token):
    """Parser token for a reference to a label by name."""

    def __init__(self, label):
        self.label = label

    def __repr__(self):
        return "<LabelRef {}>".format(self.label)

class ConstructorCall(Token):
    """Parser token for a class name together with its argument list text."""

    def __init__(self, name, arglist):
        self.name = name
        self.arglist = arglist

    def __repr__(self):
        return "<ConstructorCall {}{}>".format(self.name, self.arglist)

# Parser stage 3 tokens

class NodeDefinition(Token):
    """Parser token for a labeled node: label, node class name, argument text."""

    def __init__(self, label, node_type, arglist):
        self.label = label
        self.node_type = node_type
        self.arglist = arglist

    def __repr__(self):
        # Show "label:" in front of the constructor only when a label is set.
        if self.label:
            prefix = self.label + ':'
        else:
            prefix = ''
        return "<NodeDefinition {}{}{}>".format(prefix, self.node_type, self.arglist)

class Transition(Token):
    """Parser token for a transition between nodes.

    `sources` and `destinations` start out empty and are assigned lists of
    label names by parser3.
    """

    def __init__(self, label, trans_type, arglist):
        self.label = label
        self.trans_type = trans_type
        self.arglist = arglist
        self.sources = []
        self.destinations = []

    @staticmethod
    def _show_labels(labels):
        """Render a single label bare and any other count as {a,b,...}."""
        if len(labels) == 1:
            return labels[0]
        return '{%s}' % ','.join(labels)

    def __repr__(self):
        prefix = (self.label + ':') if self.label else ''
        srcs = self._show_labels(self.sources)
        dests = self._show_labels(self.destinations)
        return "<Transition {}={}{}=>{}>".format(srcs, prefix, self.trans_type, dests)

# Whitespace and punctuation recognized by the lexer.  Entries are tried in
# order against the start of the remaining input and the first match wins.
# Each entry is (text, token class); a token class of None means the text is
# consumed without producing a token.
_lex_punc_table = (
    (' ', None),
    ('\t', None),
    ('\n', Newline),
    # '\r\n' must precede '\r' in this list for line counting to work correctly
    ('\r\n', Newline),
    ('\r', Newline),
    (':', Colon),
    (',', Comma),
    # '=>' must precede Equal in this list
    ('=>', ArrowHead),
    ('=', Equal),
    ('{', LBrace),
    ('}', RBrace),
)

# Number of the input line currently being processed; report_line_error()
# uses it as an index into line_cache.
current_line = 0

def handle_newline():
    """Advance the global line counter past one line break."""
    global current_line
    current_line = current_line + 1

def lexer(string):
    """Convert input string into a sequence of lexer tokens.

    Whitespace is dropped, '#' comments are skipped, a '(' starts an
    argument list that is captured as a single Arglist token, and names
    (optionally prefixed with 'self.') become Identifier tokens.  Text that
    fits none of these is reported as a syntax error and the rest of that
    line is discarded.
    """
    identifier_re = re.compile(r'((self\.)|)\w+')
    tokens = []
    while string:
        # Whitespace and punctuation: first matching table entry wins.
        for chars, token_class in _lex_punc_table:
            if not string.startswith(chars):
                continue
            if token_class:
                token = token_class()
                tokens.append(token)
                if token.isNewline():
                    handle_newline()
            string = string[len(chars):]
            break
        else:
            # No table entry matched the start of the input.
            first = string[0]
            if first == '#':
                string = lexer_skip_comment(string)
            elif first == '(':
                arglist, string = lexer_build_arglist(string)
                tokens.append(arglist)
            else:
                found = identifier_re.match(string)
                if found:
                    end = found.end()
                    tokens.append(Identifier(string[:end]))
                    string = string[end:]
                else:
                    # Indigestible input: report it, then resume at the next
                    # line break, or stop if there is none.
                    report_line_error("syntax error at '%s'" % error_fragment(string))
                    newline_at = string.find('\n')
                    if newline_at == -1:
                        break
                    string = string[newline_at:]
    return tokens

def lexer_skip_comment(string):
    """Skip a comment that starts at the beginning of `string`.

    `string` begins with '#'.  Returns the remainder of `string` starting at
    the first line-break character ('\\r' or '\\n', whichever comes first), so
    the caller still sees the line break itself.

    Raises Exception if `string` contains no line break at all.
    """
    # Take the earliest line break of either kind.  Looking for '\r' first
    # and only then for '\n' would skip past whole lines when a '\n'-ended
    # comment is followed later by a '\r'.
    breaks = [pos for pos in (string.find('\r'), string.find('\n')) if pos != -1]
    if not breaks:
        raise Exception('Missing newline at end of comment.')
    return string[min(breaks):]

def lexer_build_arglist(string):
    """Helper for lexer.  Parses an argument list and returns it plus the remainder of the string.

    `string` begins with '('.  The scan tracks nested (), [] and {} and
    stops after the ')' that balances the opening one.  Brackets inside
    quoted strings (single, double or triple quoted, with backslash escapes)
    and inside '#' comments are ignored.

    Returns (Arglist, rest).  If the list is not properly closed, an error
    is reported and (Arglist(''), '') is returned, discarding the rest of
    the input.
    """
    ostring = string
    openers = {'(': ')', '[': ']', '{': '}'}
    lookstack = [')']   # closers still owed, innermost first
    pos = 1
    while lookstack and pos < len(string):
        ch = string[pos]
        closer = lookstack[0]
        if closer[0] in '\'"':
            # Inside a string literal: only its closing quote matters.
            if ch == '\\':
                pos += 1    # the escaped character cannot end the string
            elif string.startswith(closer, pos):
                del lookstack[0]
                pos += len(closer) - 1
        elif ch == closer:
            del lookstack[0]
        elif ch in '\'"':
            # Opening quote; three in a row start a triple-quoted string.
            triple = ch * 3
            if string.startswith(triple, pos):
                lookstack.insert(0, triple)
                pos += 2
            else:
                lookstack.insert(0, ch)
        elif ch == '#':
            # Comment inside a multi-line argument list: ignore it up to,
            # but not including, the end of the line.
            eol = string.find('\n', pos)
            pos = len(string) if eol == -1 else eol
            continue
        elif ch in openers:
            lookstack.insert(0, openers[ch])
        elif ch in ')]}':
            # A closer that does not match the innermost opener.
            break
        pos += 1
    if lookstack:
        cleanstr = string.strip()
        # pos may have run off the end of the input; clamp it for the message.
        p = min(pos, len(cleanstr)-1)
        report_line_error("Ill-formed argument list at '%s' near '%s'" %
                          (error_fragment(ostring), cleanstr[p]))
        return Arglist(''), ''
    return Arglist(string[0:pos]), string[pos:]

def parser1(lex_tokens):
    """Assembles label-def / constructor-call / label-ref / identifier-list tokens.

    Consumes `lex_tokens` (the list is emptied in place) and returns a new
    list in which:
      * identifier ':'        becomes a LabelDef
      * identifier arglist    becomes a ConstructorCall
      * any other identifier  becomes a LabelRef
      * '{' id , id ... '}'   becomes an IdentifierList
    Equal, ArrowHead and Newline tokens are passed through unchanged.
    Problems are reported through report_line_error().
    """
    p1tokens = []
    while lex_tokens:
        if lex_tokens[0].isIdentifier():
            # An identifier must be a label definition, constructor call, or label reference.
            if len(lex_tokens) > 1 and lex_tokens[1].isColon():   # Label definition
                p1tokens.append(LabelDef(lex_tokens[0].name))
                del lex_tokens[0:2]
                continue
            if len(lex_tokens) > 1 and lex_tokens[1].isArglist(): # Constructor call
                p1tokens.append(ConstructorCall(lex_tokens[0].name,lex_tokens[1].value))
                del lex_tokens[0:2]
                continue
            # Default case: identifier assumed to be a label reference.
            p1tokens.append(LabelRef(lex_tokens[0].name))
            del lex_tokens[0]
            continue
        # A left brace introduces a comma-separated list of label references
        if lex_tokens[0].isLBrace():
            del lex_tokens[0]
            label_refs = []
            need_comma = False    # True right after an identifier has been read
            need_rbrace = True    # cleared once the closing brace is seen
            while lex_tokens:
                if lex_tokens[0].isRBrace():
                    p1tokens.append(IdentifierList(label_refs))
                    del lex_tokens[0]
                    need_rbrace = False
                    break
                if need_comma and lex_tokens[0].isComma():
                    del lex_tokens[0]
                    need_comma = False
                elif not need_comma and lex_tokens[0].isIdentifier():
                    label_refs.append(lex_tokens[0].name)
                    del lex_tokens[0]
                    need_comma = True
                else:
                    # Anything else is reported and dropped; scanning goes on.
                    report_line_error('Syntax error in identifier list near %s.' %lex_tokens[0])
                    del lex_tokens[0]
            if not label_refs:
                report_line_error('Empty identifier list {}.')
            if label_refs and not need_comma:
                report_line_error('Trailing comma in identifier list {... ,}.')
            if need_rbrace:
                # Input ran out before '}'; no IdentifierList was emitted.
                report_line_error('Missing right brace in identifier list {....')
            continue
        if lex_tokens[0].isRBrace():
            report_line_error('Extraneous right brace.')
            del lex_tokens[0]
            continue
        if lex_tokens[0].isEqual():
            p1tokens.append(lex_tokens[0])
            del lex_tokens[0]
            # Special handling for transitions: convert identifier to
            # labelref if followed by ":" or to constructor call with
            # no arguments if followed by "=>". Otherwise just
            # continue and we'll process "identifier (" on the next
            # iteration.
            if len(lex_tokens) < 2:
                report_line_error('Syntax error in transition near %s' %
                                  (lex_tokens[0] if len(lex_tokens) > 0 else 'end of line'))
                continue
            # Assemble optional label for transition.
            if lex_tokens[0].isIdentifier() and lex_tokens[1].isColon():
                p1tokens.append(LabelDef(lex_tokens[0].name))
                del lex_tokens[0:2]
            # Re-check the length: removing the label may have left too little.
            if len(lex_tokens) < 2:
                report_line_error('Syntax error in transition near %s' %
                                  (lex_tokens[0] if len(lex_tokens) > 0 else 'end of line'))
                continue
            # For transitions, an identifier with no arglist is still a constructor call.
            if lex_tokens[0].isIdentifier():
                if len(lex_tokens) >= 2 and not lex_tokens[1].isArglist():
                    p1tokens.append(ConstructorCall(lex_tokens[0].name,'()'))
                    del lex_tokens[0]
                # An identifier that does have an arglist is left in place and
                # becomes a ConstructorCall on the next iteration.
                continue
        # Default: just pass the item (arrowhead, newline) on to the next stage
        if lex_tokens[0].isNewline(): handle_newline()
        p1tokens.append(lex_tokens[0])
        del lex_tokens[0]
    return p1tokens

def parser2(p1tokens):
    """Create a node definition with label, or a transition with label
    and constructor call; no sources/destinations yet.

    Consumes `p1tokens` in place and returns the stage 2 token list.
    Unlabeled nodes and transitions are given generated labels.  Newline,
    LabelRef and IdentifierList tokens are passed through unchanged.
    """
    p2tokens = []
    while p1tokens:
        head = p1tokens[0]
        if head.isNewline():
            handle_newline()
            p2tokens.append(p1tokens.pop(0))
        elif head.isLabelDef():
            # A label must be followed by the constructor call it names.
            follower = p1tokens[1]
            if follower.isConstructorCall():
                p2tokens.append(NodeDefinition(head.label, follower.name, follower.arglist))
                del p1tokens[0:2]
            else:
                hint = ""
                if follower.isLabelRef() and follower.label[0].isupper():
                    hint = "\n\tDid you mean '%s()' ?" % follower.label
                report_line_error("Label '%s:' should be followed by a node definition, not %s.%s"
                                  % (head.label, follower, hint))
                del p1tokens[0]
        elif head.isConstructorCall():
            # Unlabeled constructor call: label it.
            generated = gen_name(head.name)
            p2tokens.append(NodeDefinition(generated, head.name, head.arglist))
            del p1tokens[0]
        elif head.isEqual():
            # Start of a transition: '=' [label ':'] constructor '=>'
            del p1tokens[0]
            label = None
            if p1tokens[0].isLabelDef():
                label = p1tokens.pop(0).label
            constructor = p1tokens.pop(0)
            if not constructor.isConstructorCall():
                report_line_error('Ill-formed transition: should not see %s here.' % constructor)
                continue
            # The token after the constructor is dropped whether or not it
            # is the expected arrowhead.
            arrow = p1tokens.pop(0)
            if not arrow.isArrowHead():
                report_line_error("Error in transition: expected '=>' not %s." % arrow)
            # Expand an abbreviation such as C to its full class name.
            trans_class = transition_names.get(constructor.name, constructor.name)
            if not label:
                label = gen_name(trans_class)
            p2tokens.append(Transition(label, trans_class, constructor.arglist))
        elif head.isIdentifierList() or head.isLabelRef():
            # Pass along an identifier list or label reference without modification
            p2tokens.append(head)
            del p1tokens[0]
        else:
            report_line_error("A %s token is not legal in this context." % head)
            del p1tokens[0]
    return p2tokens

# Abbreviations accepted inside =...=> and the transition class each one
# stands for.  A name that is not listed here is used as the class name as is.
transition_names = {
    'N': 'NullTrans',
    'T': 'TimerTrans',
    'C': 'CompletionTrans',
    'S': 'SuccessTrans',
    'F': 'FailureTrans',
    'D': 'DataTrans',
    'TM': 'TextMsgTrans',
    'RND': 'RandomTrans',
    'PILOT': 'PilotTrans',
    'Tap': 'TapTrans',
    'ObsMot': 'ObservedMotionTrans',
    'UnexMov': 'UnexpectedMovementTrans',
    'Aruco': 'ArucoTrans',
    'Next': 'NextTrans',
    'CNext': 'CNextTrans',
    'SayData': 'SayDataTrans',
    'Hear': 'HearTrans',
}

# Counters used by gen_name() when the caller does not supply its own
# dictionary: lower-cased base name -> how many names were generated from it.
_gen_name_counts = {}

def gen_name(base_name, name_counts=None):
    """Return a fresh label derived from `base_name`.

    The label is `base_name` in lower case, without a leading 'self.',
    followed by a counter that starts at 1 and increases each time the same
    base is used, e.g. 'Forward' -> 'forward1', then 'forward2'.

    `name_counts` is the dictionary holding the counters; it is updated in
    place.  When omitted, a module-level dictionary shared by all such calls
    is used, so generated names stay unique for the whole run.
    """
    if name_counts is None:
        name_counts = _gen_name_counts
    name = base_name.lower()
    if name.startswith('self.'):
        name = name[5:]
    count = name_counts.get(name, 0) + 1
    name_counts[name] = count
    return name + str(count)

def parser3(p2tokens):
    """Chain nodes and transitions by filling in source/destination fields.

    Consumes `p2tokens` in place.  The returned list holds only the
    NodeDefinition and Transition tokens; label references and identifier
    lists are used to set each transition's `sources` and `destinations`
    and are then dropped.  Problems are reported with report_line_error().
    """
    current_node = None       # label list of the most recently seen node(s)
    need_destination = False  # set when a transition is appended
    p3tokens = []
    must_transition = False   # True when a node was just read on the current line
    while p2tokens:
        while p2tokens and p2tokens[0].isNewline():
            must_transition = False
            handle_newline()
            del p2tokens[0]
        if not p2tokens: break
        # Step 1: an optional node reference, node definition or list of
        # references, which becomes the source of the next transition.
        if p2tokens[0].isLabelRef():
            must_transition = True
            current_node = [p2tokens[0].label]
            del p2tokens[0]
        elif p2tokens[0].isNodeDefinition():
            must_transition = True
            current_node = [p2tokens[0].label]
            p3tokens.append(p2tokens[0])
            del p2tokens[0]
        elif p2tokens[0].isIdentifierList():
            must_transition = True
            current_node = p2tokens[0].label_refs
            del p2tokens[0]
        elif not current_node:
            report_line_error('Node reference expected before this transition: %s' % p2tokens[0])
        # node definition could be followed by newlines
        while p2tokens and p2tokens[0].isNewline():
            must_transition = False
            handle_newline()
            del p2tokens[0]
        if not p2tokens: break
        # next item must be a transition
        if p2tokens[0].isTransition():
            # check for source
            if not current_node:
                report_line_error('Transition %s has no source nodes.' % p2tokens[0].label)
            p2tokens[0].sources = current_node
            need_destination = True
            p3tokens.append(p2tokens[0])
            del p2tokens[0]
        elif must_transition:
            report_line_error("Expected a transition after '%s', not %s." %
                              (','.join(current_node), p2tokens[0]))
            del p2tokens[0]
            continue
        while p2tokens and p2tokens[0].isNewline():
            handle_newline()
            del p2tokens[0]
        if not p2tokens:
            # NOTE(review): p3tokens[-1] is taken to be the transition just
            # appended; nothing here checks that p3tokens is non-empty.
            report_line_error('Missing destination for transition %s.' % p3tokens[-1].label)
            continue
        # next item must be a destination for the transition
        if p2tokens[0].isLabelRef():
            current_node = [p2tokens[0].label]
            del p2tokens[0]
            if need_destination:
                if p3tokens[-1].isTransition():
                    p3tokens[-1].destinations = current_node
                    need_destination = False
            continue
        elif p2tokens[0].isNodeDefinition():
            current_node = [p2tokens[0].label]
            # NOTE(review): unlike the other two branches this one does not
            # test need_destination or that p3tokens[-1] is a Transition, and
            # it leaves need_destination unchanged; it raises IndexError if
            # p3tokens is empty.  Confirm whether that is intended.
            p3tokens[-1].destinations = current_node
            continue  # process the node definition on the next iteration
        elif p2tokens[0].isIdentifierList():
            current_node = p2tokens[0].label_refs
            del p2tokens[0]
            if need_destination:
                if p3tokens[-1].isTransition():
                    p3tokens[-1].destinations = current_node
                    need_destination = False
        else:
            raise Exception('parser3 is confused by %s.' % p2tokens)
    return p3tokens

def generate_machine(lines):
    """Translate the lines of one state machine into a setup() method.

    `lines` are the source lines between the $setup line and its closing
    delimiter.  They are run through the lexer and the three parser stages;
    if any stage reports an error, nothing is written.  Otherwise label
    definitions and references are checked and the method is written to
    out_f: the original notation as a comment, then one statement per node,
    then the statements for each transition.
    """
    global indent_level, current_line, found_error
    found_error = False

    # Every stage starts counting lines again from the top of the machine,
    # and the first stage that reports an error ends the translation.
    tokens = ''.join(lines)
    for stage in (lexer, parser1, parser2, parser3):
        current_line = starting_line
        tokens = stage(tokens)
        if found_error:
            return
    p3tokens = tokens

    # Record each label definition, complaining about duplicates.
    labels = {}
    for item in p3tokens:
        if item.label in labels:
            report_global_error("Label '%s:' is multiply defined." % item.label)
            continue
        if not (item.isNodeDefinition() or item.isTransition()):
            raise Exception("Problem in generate_machine: %s" % item)
        labels[item.label] = item

    # Check for undefined references
    for item in p3tokens:
        if not item.isTransition():
            continue
        for ref in item.sources + item.destinations:
            if ref in labels:
                continue
            hint = (" Should it be %s() ?" % ref) if ref[0].isupper() else ""
            report_global_error("Label '%s' was referenced but never defined.%s" %
                                (ref,hint))
            # Remember the label so it is reported only once.
            labels[ref] = None

    # Write out the state machine source as a comment
    emit_line('def setup(self):')
    indent_level += 4
    comment_prefix = ' ' * indent_level + "# "
    out_f.write(comment_prefix + comment_prefix.join(lines))
    emit_line('')
    emit_line('# Code generated by genfsm on %s:' % time.strftime('%c'))
    emit_line('')

    # Generate the nodes, then the transitions
    nodes = [item for item in p3tokens if item.isNodeDefinition()]
    transitions = [item for item in p3tokens if item.isTransition()]
    for node in nodes:
        emit_line('%s = %s%s .set_name("%s") .set_parent(self)' %
                  (node.label, node.node_type, node.arglist, node.label))
    for trans in transitions:
        emit_line('')
        emit_line('%s = %s%s .set_name("%s")' %
                  (trans.label, trans.trans_type, trans.arglist, trans.label))
        emit_line('%s .add_sources(%s) .add_destinations(%s)' %
                  (trans.label, ','.join(trans.sources), ','.join(trans.destinations)))

    emit_line('')
    emit_line('return self')
    indent_level -= 4

# Number of spaces placed in front of every line written by emit_line().
indent_level = 0

def emit_line(line):
    """Write `line` to out_f, indented by indent_level spaces and followed by a newline."""
    out_f.write(''.join((' ' * indent_level, line, '\n')))
            
def process_file():
    """Copy in_f to out_f, replacing each $setup section with generated code.

    Lines are echoed unchanged until one containing '$setup' is found.  The
    lines from there up to the matching closing delimiter (''' or \"\"\" or
    a closing brace) are collected and handed to generate_machine().  Every
    line read is also stored in line_cache for use in error messages.
    """
    global line_cache, current_line, starting_line, indent_level
    line_cache = [None] # dummy line 0
    current_line = 0
    r_setup = re.compile(r'^\s*\$setup\s*((\"\"\")|(\'\'\')|\{)\s*((\#.*)|)$')
    r_indent = re.compile(r'^\s*')
    for line in iter(in_f.readline, ''):
        line_cache.append(line)
        current_line += 1
        # Echo lines to the output file until we reach a $setup line.
        if '$setup' not in line:
            out_f.write(line)
            continue
        setup_match = r_setup.match(line)
        if not setup_match:
            report_line_error("Incorrect $setup syntax: '%s'" % line.strip())
            continue

        # The opening delimiter determines what ends the state machine.
        delim = setup_match.group(1)[0]
        if delim == '{':
            close_delim = '}'
            r_end = re.compile(r'\s*\}\s*$')
        else:
            close_delim = delim * 3
            r_end = re.compile(r'^\s*' + close_delim)

        # Collect the lines of the state machine.
        starting_line = current_line + 1
        # The generated code is indented like the $setup line itself.
        indent_level = r_indent.match(line).end()
        lines = []
        for body_line in iter(in_f.readline, ''):
            current_line += 1
            line_cache.append(body_line)
            if r_end.match(body_line):
                break
            lines.append(body_line)
        else:
            # Input ended before the closing delimiter was seen.
            report_line_error("State machine at line %s ended without closing %s." %
                              (starting_line-1, close_delim))
            return
        # Now parse the collected lines and generate code.
        generate_machine(lines)

# Set to True by the report_*_error() functions.
found_error = False

def report_line_error(error_text):
    """Show the current input line and `error_text` on stderr and set found_error."""
    global found_error
    offending_line = line_cache[current_line].rstrip()
    message = 'Line %d: %s\n' % (current_line, error_text)
    for text in (offending_line, message):
        cprint(text, color='red', file=sys.stderr)
    found_error = True

def report_global_error(error_text):
    """Show `error_text` on stderr, without a line number, and set found_error."""
    global found_error
    message = 'Error: %s\n' % error_text
    cprint(message, color='red', file=sys.stderr)
    found_error = True

def error_fragment(string):
    s = string.strip()
    p = s.find('\n')
    if p == -1:
        p = len(s)
    fragment = s[0:min(p,20)]
    if len(fragment) < p:
        fragment += "..."
    return fragment

def default_output_name(infile_name):
    """Return the default output file name for `infile_name`.

    The file extension, if there is one, is replaced by '.py'; a name
    without an extension simply gets '.py' appended.  Dots in directory
    names are left alone.
    """
    import os.path
    return os.path.splitext(infile_name)[0] + ".py"

if __name__ == '__main__':
    # Command line: genfsm [infile.fsm | -] [outfile.py | -]
    if len(sys.argv) < 2 or len(sys.argv) > 3:
        print('Usage: genfsm [infile.fsm | -] [outfile.py | -]')
        sys.exit(0)

    infile_name = sys.argv[1]
    if len(sys.argv) == 3:
        outfile_name = sys.argv[2]
    elif infile_name == '-':
        # Reading standard input with no output name: write to standard output.
        outfile_name = '-'
    else:
        outfile_name = default_output_name(infile_name)
        if infile_name == outfile_name:
            # The input already ends in .py; refuse to overwrite it.
            print("Output file name can't be the same as input file.\nDid you mean %s ?" %
                  (outfile_name[:-len(".py")] + ".fsm"))
            sys.exit(1)

    # in_f and out_f are module-level names used by process_file(),
    # generate_machine() and emit_line().
    try:
        with (open(infile_name) if infile_name != '-' else sys.stdin) as in_f:
            try:
                with (open(outfile_name,'w') if outfile_name != '-' else sys.stdout) as out_f:
                    process_file()
                    if not found_error:
                        cprint('Wrote generated code to %s.' %
                               (outfile_name if outfile_name != '-' else 'standard output'),
                               color='green')
            except Exception as e:
                print('Error opening output file: %s' % e)
                import traceback
                traceback.print_exc()
                sys.exit(1)
    except Exception as e:
        print('Error opening input file: %s' % e)
        sys.exit(1)
    sys.exit(0)
